/* Copyright 2017 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

export type Vector = number[] | Float64Array | Float32Array | Int32Array |
    Int8Array | Int16Array;

/** Shuffles the array using Fisher-Yates algorithm. */
// tslint:disable-next-line:no-any
export function shuffle(array: any[]|Uint32Array|Int32Array|
                        Float32Array): void {
  let counter = array.length;
  let temp = 0;
  let index = 0;
  // While there are elements in the array
  while (counter > 0) {
    // Pick a random index
    index = (Math.random() * counter) | 0;
    // Decrease counter by 1
    counter--;
    // And swap the last element with it
    temp = array[counter];
    array[counter] = array[index];
    array[index] = temp;
  }
}

/** Clamps a value to a specified range. */
export function clamp(min: number, x: number, max: number): number {
  return Math.max(min, Math.min(x, max));
}

/** Returns a sample from a uniform [a, b] distribution. */
export function randUniform(a: number, b: number) {
  return Math.random() * (b - a) + a;
}

/**
 * Samples from a gaussian distribution.
 *
 * @param mean The mean. Default is 0.
 * @param stdDev The standard deviation. Default is 1.
 */
export function randGauss(mean = 0, stdDev = 1, truncated = false): number {
  let v1: number, v2: number, s: number;
  do {
    v1 = 2 * Math.random() - 1;
    v2 = 2 * Math.random() - 1;
    s = v1 * v1 + v2 * v2;
  } while (s > 1);

  const result = Math.sqrt(-2 * Math.log(s) / s) * v1;
  if (truncated && result > 2) {
    return randGauss(mean, stdDev, true);
  }
  return mean + stdDev * result;
}

/** Returns squared eucledian distance between two vectors. */
export function distSquared(a: Vector, b: Vector): number {
  let result = 0;
  for (let i = 0; i < a.length; i++) {
    const diff = a[i] - b[i];
    result += diff * diff;
  }
  return result;
}

export function assert(expr: boolean, msg: string) {
  if (!expr) {
    throw new Error(msg);
  }
}

export function assertShapesMatch(
    shapeA: number[], shapeB: number[], errorMessagePrefix = ''): void {
  assert(
      arraysEqual(shapeA, shapeB),
      errorMessagePrefix + `Shapes ${shapeA} and ${shapeB} must match`);
}

// tslint:disable-next-line:no-any
export function flatten(arr: any[], ret?: number[]): number[] {
  ret = (ret === undefined ? [] : ret);
  for (let i = 0; i < arr.length; ++i) {
    if (Array.isArray(arr[i])) {
      flatten(arr[i], ret);
    } else {
      ret.push(arr[i]);
    }
  }
  return ret;
}

export type ArrayData = number|number[]|number[][]|number[][][]|number[][][][];

export function inferShape(arr: ArrayData): number[] {
  const shape: number[] = [];
  while (arr instanceof Array) {
    shape.push(arr.length);
    arr = arr[0];
  }
  return shape;
}

export function sizeFromShape(shape: number[]): number {
  if (shape.length === 0) {
    // Scalar.
    return 1;
  }
  let size = shape[0];
  for (let i = 1; i < shape.length; i++) {
    size *= shape[i];
  }
  return size;
}

export function isScalarShape(shape: number[]): boolean {
  return shape.length === 0;
}

// tslint:disable-next-line:no-any
export function arraysEqual(n1: any[]|Float32Array, n2: any[]|Float32Array) {
  if (n1.length !== n2.length) {
    return false;
  }
  for (let i = 0; i < n1.length; i++) {
    if (n1[i] !== n2[i]) {
      return false;
    }
  }
  return true;
}

export function isInt(a: number): boolean {
  return a % 1 === 0;
}

export function tanh(x: number): number {
  // tslint:disable-next-line:no-any
  if ((Math as any).tanh != null) {
    // tslint:disable-next-line:no-any
    return (Math as any).tanh(x);
  }
  if (x === Infinity) {
    return 1;
  } else if (x === -Infinity) {
    return -1;
  } else {
    const e2x = Math.exp(2 * x);
    return (e2x - 1) / (e2x + 1);
  }
}

export function sizeToSquarishShape(size: number): [number, number] {
  for (let a = Math.floor(Math.sqrt(size)); a > 1; --a) {
    if (size % a === 0) {
      return [a, size / a];
    }
  }
  return [1, size];
}

export function createShuffledIndices(n: number): Uint32Array {
  const shuffledIndices = new Uint32Array(n);
  for (let i = 0; i < n; ++i) {
    shuffledIndices[i] = i;
  }
  shuffle(shuffledIndices);
  return shuffledIndices;
}

export function assertAndGetBroadcastedShape(
    shapeA: number[], shapeB: number[]): number[] {
  const result: number[] = [];
  let nextADimMustBeOne = false;
  let nextBDimMustBeOne = false;
  const errMsg = `Operands could not be broadcast together with shapes ` +
      `${shapeA} and ${shapeB}. Currently, we only support a ` +
      `stricter version of broadcasting than numpy.`;
  const l = Math.max(shapeA.length, shapeB.length);

  shapeA = shapeA.slice().reverse();
  shapeB = shapeB.slice().reverse();
  for (let i = 0; i < l; i++) {
    const a = shapeA[i] || 1;
    const b = shapeB[i] || 1;
    if ((b > 1 && nextBDimMustBeOne) || (a > 1 && nextADimMustBeOne)) {
      throw Error(errMsg);
    }
    if (a > 1 && b === 1) {
      nextBDimMustBeOne = true;
    }
    if (b > 1 && a === 1) {
      nextADimMustBeOne = true;
    }
    if (a > 1 && b > 1 && a !== b) {
      throw Error(errMsg);
    }
    result.push(Math.max(a, b));
  }
  return result.reverse();
}
